#################################################################################
# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
# All Rights Reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################
# Some parts of the code are borrowed from: https://github.com/pytorch/vision
# with the following license:
#
# BSD 3-Clause License
#
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################

import torch
from .utils import *
from ... import xnn
import numpy as np
from .mobilenetv2 import MobileNetV2TV, get_config as get_config_mnetv2

# public API of this module
# (note: a 'StudentTeacherLearnerNV12' entry was once listed here but is not
# defined in this file -- appears to be left over from an earlier revision)
__all__ = ['MobileNetV2TVNV12', 'MobileNetV2TVGWS', 'MobileNetV2TVGWSBase', 'mobilenet_v2_tv_gws']


###################################################
class ColorMixerNV12(torch.nn.Module):
    """Stem that fuses the two planes of an NV12 input into one feature map.

    The Y branch strides in both dimensions while the UV branch strides only
    horizontally (stride=(1, s)); presumably the UV plane arrives at half
    height, so this brings both branches to the same spatial size before
    mixing -- TODO confirm against the input pipeline.
    """

    def __init__(self, inp, intermediate, oup, stride):
        super().__init__()
        # luma branch: full 2-D stride
        self.conv_y = xnn.layers.ConvNormAct2d(inp, intermediate, kernel_size=3, stride=stride)
        # chroma branch: horizontal stride only
        self.conv_uv = xnn.layers.ConvNormAct2d(inp, intermediate, kernel_size=3, stride=(1, stride))
        # fuses the summed branches into the output channels
        self.conv_mixer = xnn.layers.ConvNormAct2d(intermediate, oup, kernel_size=3)

    def forward(self, x):
        assert isinstance(x, (list,tuple)) and len(x)>=2, 'input must be a list/tuple of length 2'
        luma, chroma = x[0], x[1]
        fused = self.conv_y(luma) + self.conv_uv(chroma)
        return self.conv_mixer(fused)

class MobileNetV2TVNV12(MobileNetV2TV):
    """MobileNetV2 (torchvision flavour) whose stem accepts NV12 (Y, UV) planes.

    Builds the regular MobileNetV2TV and then swaps its first feature layer
    for a ColorMixerNV12 stem with matching stride and output channels.
    """

    def __init__(self, **kwargs):
        model_config = get_config_mnetv2()
        if 'model_config' in kwargs:
            model_config = model_config.merge_from(kwargs['model_config'])
        super().__init__(model_config=model_config)
        stride_first = model_config.strides[0]
        mixer_channels = 6
        stem_channels = xnn.utils.make_divisible_by8(self.model_config.layer_setting[0][1] * self.model_config.width_mult)
        # replace the RGB stem with the two-plane colour mixer (1 channel per plane)
        self.features[0] = ColorMixerNV12(1, mixer_channels, stem_channels, stride_first)


###################################################
def get_config_mnetv2_gws():
    """Return the default configuration for the grouped (GWS) MobileNetV2 variants."""
    model_config = xnn.utils.ConfigNode()
    defaults = dict(
        input_channels=3,
        num_classes=1000,
        width_mult=1.,
        expand_ratio=6,
        output_stride=32,
        activation=xnn.layers.DefaultAct2d,
        use_blocks=False,
        kernel_size=3,
        dropout=False,
        linear_dw=False,
        layer_setting=None,  # None => table is generated later by define_nw()
    )
    for key, value in defaults.items():
        setattr(model_config, key, value)
    return model_config


###################################################
def define_nw(nw_type='mobv2_orig', s=None, model_config=None):
    """Return the layer_setting table describing one network variant.

    Every row of a table has the form:
        [intermediate_c, c, n, s, group_size_dws, group_pw, group_lin, el_add]
    where:
        intermediate_c : expanded (hidden) channel count of the block
        c              : output channel count
        n              : number of repetitions of the block
        s              : stride used by the first repetition
        group_size_dws : group size for the depthwise-separable conv
        group_pw       : groups for the pointwise (expansion) 1x1 conv
        group_lin      : groups for the linear (projection) 1x1 conv
        el_add         : 1: use element-wise addition if it is free, 0: don't use

    Args:
        nw_type: name of the variant; one of the keys of nw_configs below.
        s: sequence of five stage strides, indexed s[0]..s[4].
        model_config: config object providing group_size_dws and expand_ratio.

    Returns:
        The selected table (list of rows as described above).

    Raises:
        ValueError: if nw_type is not a known variant name.
    """

    mobv2_orig = [
       [32, 32,   1, s[0], 1, 1, 1, 1],                 #blk_1
       [32, 16,   1, 1,    1, 1, 1, 1],                 #blk_2_1
       [96, 24,   1, s[1], 1, 1, 1, 1],                 #blk_3_1
       [144, 24,  1, 1,    1, 1, 1, 1],                 #blk_3_2
       [144, 32,  1, s[2], 1, 1, 1, 1],                 #blk_4_1
       [192, 32,  2, 1,    1, 1, 1, 1],                 #4_2, 4_3
       [192, 64,  1, s[3], 1, 1, 1, 1],                 #blk_5_1
       [384, 64,  3, 1,    1, 1, 1, 1],                 #5_2, 5_3, 5_4
       [384, 96,  1, 1,    1, 1, 1, 1],                 #blk_6_1
       [576, 96,  2, 1,    1, 1, 1, 1],                 #blk_6_2, blk_6_3
       [576, 160, 1, s[4], 1, 1, 1, 1],                 #blk_7_1
       [960, 160, 2, 1,    1, 1, 1, 1],                 #blk_7_2, blk_7_3
       [960, 320, 1, 1,    1, 1, 1, 1],                 #blk8_1
       [320, 1280,1, 1,    1, 1, 1, 1],
     ]

    mobv2_dws_gwsv2 = [
      [32, 32,    1, s[0], 1, 1, 1, 1],                 #blk_1
      [32, 32,    1, 1,    4, 1, 1, 1],                 #blk_2_1
      [124, 63,   1, s[1], 4, 1, 1, 1],                 #blk_3_1
      [188, 63,   1, 1,    4, 1, 1, 1],                 #blk_3_2
      [188, 63,   3, s[2], 4, 1, 1, 1],                 #blk_4_1, 4_2, 4_3
      [188, 63,   1, s[3], 4, 1, 1, 1],                 #blk_5_1
      [380, 63,   3, 1,    4, 1, 1, 1],                 #5_2, 5_3, 5_4
      [380, 124,  1, 1,    4, 1, 2, 1],                 #blk_6_1
      [636, 124,  2, 1,    4, 2, 2, 1],                 #blk_6_2, blk_6_3
      [636, 252,  1, s[4], 4, 2, 4, 1],                 #blk_7_1
      [1020, 252, 2, 1,    4, 4, 4, 1],                 #blk_7_2, blk_7_3
      [1020, 320, 1, 1,    4, 4, 4, 1],                 #blk8_1
      [320, 1280, 1, 1,    1, 1, 1, 1],
    ]

    mobv2_dws_gwsv3 = [
      [32,   32,  1, s[0], 1, 1, 1, 1],                 #blk_1
      [32,   32,  1, 1,    4, 1, 1, 0],                 #blk_2_1
      [96,   63,  1, s[1], 4, 1, 1, 1],                 #blk_3_1
      [144,  63,  1, 1,    4, 1, 1, 1],                 #blk_3_2
      [188,  63,  3, s[2], 4, 1, 1, 1],                 #blk_4_1, 4_2, 4_3
      [188,  63,  1, s[3], 4, 1, 1, 1],                 #blk_5_1
      [380,  63,  3, 1,    4, 1, 1, 1],                 #5_2, 5_3, 5_4
      [380,  124, 1, 1,    4, 1, 2, 1],                 #blk_6_1
      [636,  124, 2, 1,    4, 2, 2, 1],                 #blk_6_2, blk_6_3
      [636,  252, 1, s[4], 4, 2, 4, 1],                 #blk_7_1
      [1020, 252, 2, 1,    4, 4, 4, 1],                 #blk_7_2, blk_7_3
      [1020, 508, 1, 1,    4, 4, 4, 1],                 #blk8_1
      [320, 1280, 1, 1,    1, 1, 1, 1],
    ]

    mobv2_dws_gwsv4 = [
      [32,      32, 1, s[0],    1, 1, 1, 1],                 #blk_1
      [32,      32, 1, 1,      32, 1, 1, 0],                 #blk_2_1
      [96,      62, 1, s[1],   48, 1, 1, 1],                 #blk_3_1
      [62*2,    62, 1, 1,      62, 1, 1, 1],                 #blk_3_2
      [62*3,    62, 3, s[2],   62, 1, 1, 1],                 #blk_4_1, 4_2, 4_3
      [62*3,    62, 1, s[3],   62, 1, 1, 1],                 #blk_5_1
      [62*6,    62, 3, 1,      62, 1, 1, 1],                 #5_2, 5_3, 5_4
      [62*6,  62*2, 1, 1,      62, 1, 1, 1],                 #blk_6_1
      [62*10, 62*2, 2, 1,      31, 1, 1, 1],                 #blk_6_2, blk_6_3
      [62*10, 62*4, 1, s[4],   31, 1, 1, 1],                 #blk_7_1
      [62*16, 62*4, 2, 1,      31, 1, 1, 1],                 #blk_7_2, blk_7_3
      [62*16, 62*8, 1, 1,      31, 1, 1, 1],                 #blk8_1
      [62*5,  1280, 1, 1,       1, 1, 1, 1],
    ]

    # round channel counts down to a multiple of the depthwise group size
    flr = lambda a : (a//model_config.group_size_dws)*model_config.group_size_dws
    # block8 channels will be used by the decoder which expects channels to be
    # a multiple of 4, so round down to a common multiple of both constraints
    enc_dec_dws_ratio_lcm = int(np.lcm(model_config.group_size_dws, 4))
    flr_lcm = lambda a: (a // enc_dec_dws_ratio_lcm) * enc_dec_dws_ratio_lcm

    group_size_dws = model_config.group_size_dws
    b8_inter_c = flr_lcm(64 * 15)
    b8_c = flr_lcm(64 * 5)

    mobv2_gws_align = [
      [   flr(64),   flr(64), 1, s[0], 1,              1, 1, 1],                  #blk_1
      [   flr(64),   flr(64), 1, s[1], group_size_dws, 1, 1, 1],                  #blk_2_1
      [ flr(64*4),   flr(64), 2, s[2], group_size_dws, 1, 1, 1],                  #blk_4_1, 4_2
      [ flr(64*6),   flr(64), 4, s[3], group_size_dws, 1, 1, 1],                  #blk_5_1, 5_2, 5_3, 5_4
      [ flr(64*4), flr(64*2), 1, 1,    group_size_dws, 1, 1, 1],                  #blk_6_1
      [ flr(64*8), flr(64*2), 2, 1,    group_size_dws, 1, 1, 1],                  #blk_6_2, 6_3
      [flr(64*12), flr(64*2), 3, s[4], group_size_dws, 1, 1, 1],                  #blk_7_1, 7_2, 7_3
      [b8_inter_c,      b8_c, 1, 1,    group_size_dws, 1, 1, 1],                  #blk8_1
      [ flr(64*5),      1280, 1, 1,    1,              1, 1, 1],                  #blk9
    ]

          # 0,  1,  2,  3,  4,  5,  6,  7,  8,   9,  10, 11,   12, 13,   14,  15,  16
    chs = [16, 24, 32, 40, 48, 56, 64, 72, 96, 112, 128, 160, 176, 192, 320, 352, 384]
    n_ch_last = 1280
    #expansion ratio
    t = model_config.expand_ratio
    proxyless_for_latency_measure = [
      #stage1
      [  chs[0]*t,    chs[0], 1, s[0], 1,                1, 1, 1],

      #stage2
      [  chs[0]*t,    chs[1], 1, s[1], group_size_dws,  1, 1, 1],
      [  chs[1]*t,    chs[1], 1, s[1], group_size_dws,  1, 1, 1],

      #stage3
      [  chs[1]*t,    chs[2], 1, s[2], group_size_dws,  1, 1, 1],
      [  chs[2]*t,    chs[3], 1,    1, group_size_dws,  1, 1, 1],
      [  chs[3]*t,    chs[4], 1,    1, group_size_dws,  1, 1, 1],

      #stage4
      [  chs[4]*t,    chs[5], 1, s[3], group_size_dws,  1, 1, 1],
      [  chs[5]*t,    chs[6], 1,    1, group_size_dws,  1, 1, 1],
      [  chs[6]*t,    chs[7], 1,    1, group_size_dws,  1, 1, 1],
      [  chs[7]*t,    chs[8], 1,    1, group_size_dws,  1, 1, 1],
      [  chs[8]*t,    chs[9], 1,    1, group_size_dws,  1, 1, 1],
      [  chs[9]*t,   chs[10], 1,    1, group_size_dws,  1, 1, 1],

      #stage5
      [ chs[10]*t,   chs[11], 1, s[4], group_size_dws,  1, 1, 1],
      [ chs[11]*t,   chs[12], 1,    1, group_size_dws,  1, 1, 1],
      [ chs[12]*t,   chs[13], 1,    1, group_size_dws,  1, 1, 1],
      [ chs[13]*t,   chs[14], 1,    1, group_size_dws,  1, 1, 1],
      [ chs[14]*t,   chs[15], 1,    1, group_size_dws,  1, 1, 1],
      [ chs[15]*t,   chs[16], 1,    1, group_size_dws,  1, 1, 1],
      [   chs[16], n_ch_last, 1,    1,               1,  1, 1, 1],
    ]

    # only the 'proxyless' channel set is active here; 'mobv3' is kept for reference
    base_type = 'proxyless'
    if base_type == 'mobv3':
              # 0,  1,  2,  3,  4,  5,  6,  7,  8,   9,  10,  11
        chs = [24, 24, 32, 32, 48, 48, 96, 96, 136, 136, 192, 192]
        n_ch_last = 1152
    elif base_type == 'proxyless':
              # 0,  1,  2,  3,  4,  5,  6,    7,   8,   9,  10,  11
        chs = [40, 24, 32, 32, 56, 56, 104, 104, 128, 128, 248, 248]
        n_ch_last = 1488
    else:
        raise ValueError('unsupported base_type: {}'.format(base_type))

    #expansion ratio
    t = model_config.expand_ratio
    once_for_all_for_latency_measure = [
      #stage1
      [  chs[0]*t,    chs[0], 1, s[0], 1,                1, 1, 1],
      [  chs[0],      chs[1], 1,    1, group_size_dws,  1, 1, 1],

      #stage2
      [  chs[1]*t,    chs[2], 1, s[1], group_size_dws,  1, 1, 1],
      [  chs[2]*t,    chs[3], 1,    1, group_size_dws,  1, 1, 1],

      #stage3
      [  chs[3]*t,    chs[4], 1, s[2], group_size_dws,  1, 1, 1],
      [  chs[4]*t,    chs[5], 1,    1, group_size_dws,  1, 1, 1],

      #stage4
      [  chs[5]*t,    chs[6], 1, s[3], group_size_dws,  1, 1, 1],
      [  chs[6]*t,    chs[7], 1,    1, group_size_dws,  1, 1, 1],
      [  chs[7]*t,    chs[8], 1,    1, group_size_dws,  1, 1, 1],
      [  chs[8]*t,    chs[9], 1,    1, group_size_dws,  1, 1, 1],

      #stage5
      [  chs[9]*t,   chs[10], 1, s[4], group_size_dws,  1, 1, 1],
      [ chs[10]*t,   chs[11], 1,    1, group_size_dws,  1, 1, 1],

      #last ch
      [   chs[11], n_ch_last, 1,    1,               1,  1, 1, 1],
    ]

    # dispatch table instead of an if/elif chain; raises instead of exit() so
    # callers get a catchable, descriptive error
    nw_configs = {
        'mobv2_orig': mobv2_orig,
        'mobv2_dws_gwsv2': mobv2_dws_gwsv2,
        'mobv2_dws_gwsv3': mobv2_dws_gwsv3,
        'mobv2_dws_gwsv4': mobv2_dws_gwsv4,
        'mobv2_gws_align': mobv2_gws_align,
        'proxyless_for_latency_measure': proxyless_for_latency_measure,
        'once_for_all_for_latency_measure': once_for_all_for_latency_measure,
    }
    try:
        return nw_configs[nw_type]
    except KeyError:
        raise ValueError('wrong nw name: {}'.format(nw_type)) from None


###################################################
class InvertedResidual(torch.nn.Module):
    """MobileNetV2-style inverted residual block with configurable grouping.

    Structure: optional pointwise (grouped) expansion -> grouped "depthwise"
    conv -> linear pointwise (grouped) projection, with a residual add when
    stride == 1 and inp == oup.

    Args:
        inp: input channels.
        oup: output channels.
        stride: stride of the grouped conv; must be 1 or 2.
        hidden_dim: expanded (intermediate) channels; if equal to inp the
            expansion conv is skipped.
        activation: activation layer factory applied after the convs.
        kernel_size: kernel size of the grouped conv.
        group_size_dws: channels per group for the grouped conv
            (groups = hidden_dim // group_size_dws; 1 => depthwise).
        group_pw: groups for the pointwise expansion conv.
        group_lin: groups for the linear projection conv.
        el_add: accepted for layer-table compatibility but not used in this
            implementation.
        linear_dw: if True the grouped conv gets no activation, and an
            activation is appended after the projection norm instead.
    """
    def __init__(self, inp, oup, stride, hidden_dim, activation, kernel_size, 
                group_size_dws, group_pw, group_lin, el_add, linear_dw):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        # residual connection only when the shapes match exactly
        self.use_res_connect = self.stride == 1 and inp == oup
        # linear_dw moves the non-linearity from the grouped conv to after the projection
        activation_dw = (False if linear_dw else activation)
        groups_dw = hidden_dim//group_size_dws

        layers = []
        if hidden_dim != inp:
            # pw
            layers.append(xnn.layers.ConvNormAct2d(inp, hidden_dim, kernel_size=1, activation=activation, groups=group_pw))

        layers.extend([
            # dw
            xnn.layers.ConvNormAct2d(hidden_dim, hidden_dim, kernel_size=kernel_size, stride=stride, activation=activation_dw, groups=groups_dw),
            # pw-linear
            torch.nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False, groups=group_lin),
            xnn.layers.DefaultNorm2d(oup),
        ])

        if linear_dw:
            layers.append(activation(inplace=True))

        self.conv = torch.nn.Sequential(*layers)

        if self.use_res_connect:
            self.add = xnn.layers.AddBlock(signed=True)

    def forward(self, x):
        # main path, then residual add if enabled
        x1 = self.conv(x)
        if self.use_res_connect:
            x1 = self.add((x, x1))

        return x1


###################################################
class MobileNetV2TVGWSBase(torch.nn.Module):
    """MobileNetV2 backbone built from a grouped (GWS) layer-setting table.

    Args:
        ResidualBlock: block class (e.g. InvertedResidual) used to build the
            inverted-residual stages.
        model_config: config node; when model_config.layer_setting is None the
            table is generated by define_nw() from model_config.network_type.
            Each table row is
            [intermediate_c, c, n, s, group_size_dws, group_pw, group_lin, el_add].

    When model_config.num_classes is None the model acts as a pure feature
    extractor (no final 1x1 conv, pooling or classifier).
    """
    def __init__(self, ResidualBlock, model_config):
        super().__init__()
        self.model_config = model_config
        self.num_classes = self.model_config.num_classes
        strides = model_config.strides if (model_config.strides is not None) else (2,2,2,2,2)

        if self.model_config.layer_setting is None:
            self.model_config.layer_setting = define_nw(nw_type=model_config.network_type, s=strides, model_config = self.model_config)

        # first row configures the stem, last row the final 1x1 conv
        stride_first = self.model_config.layer_setting[0][3]
        input_channel = self.model_config.layer_setting[0][1]
        last_channel = self.model_config.layer_setting[-1][1]

        activation = self.model_config.activation
        width_mult = self.model_config.width_mult
        linear_dw = self.model_config.linear_dw
        kernel_size = self.model_config.kernel_size

        # building first layer
        input_channel = int(input_channel * width_mult)
        self.last_channel = int(last_channel * max(1.0, width_mult))
        features = [xnn.layers.ConvNormAct2d(3, input_channel, kernel_size=kernel_size, stride=stride_first, activation=activation)]

        # building inverted residual blocks (interior rows of the table)
        for intermediate_c, c, n, s, group_size_dws, group_pw, group_lin, el_add in self.model_config.layer_setting[1:-1]:
            output_channel = int(c * width_mult)
            for i in range(n):
                stride = s if i == 0 else 1   # only the first repetition strides
                features.append(ResidualBlock(input_channel, output_channel, stride, intermediate_c, activation, kernel_size,
                                group_size_dws, group_pw, group_lin, el_add, linear_dw))
                input_channel = output_channel
            #
        #

        # building classifier
        # fix: use identity comparison with None (was `!= None`), consistent
        # with the `is not None` check in forward()
        if self.model_config.num_classes is not None:
            # building last several layers
            features.append(xnn.layers.ConvNormAct2d(input_channel, self.last_channel, kernel_size=1, activation=activation))

            self.classifier = torch.nn.Sequential(
                torch.nn.Dropout(0.2) if self.model_config.dropout else xnn.layers.BypassBlock(),
                torch.nn.Linear(self.last_channel, self.num_classes),
            )


        # make it nn.Sequential
        self.features = torch.nn.Sequential(*features)

        # weight initialization
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)
            elif isinstance(m, torch.nn.BatchNorm2d):
                torch.nn.init.ones_(m.weight)
                torch.nn.init.zeros_(m.bias)
            elif isinstance(m, torch.nn.Linear):
                torch.nn.init.normal_(m.weight, 0, 0.01)
                torch.nn.init.zeros_(m.bias)


    def forward(self, x):
        """Run the backbone; applies pooling + classifier only when num_classes is set."""
        x = self.features(x)
        if self.num_classes is not None:
            xnn.utils.print_once('=> feature size is: ', x.size())
            x = torch.nn.functional.adaptive_avg_pool2d(x,(1,1))
            x = torch.flatten(x, 1)
            x = self.classifier(x)

        return x

class MobileNetV2TVGWS(MobileNetV2TVGWSBase):
    """Grouped MobileNetV2 assembled from InvertedResidual blocks."""

    def __init__(self, **kwargs):
        config = get_config_mnetv2_gws()
        if 'model_config' in kwargs:
            config = config.merge_from(kwargs['model_config'])
        super().__init__(InvertedResidual, config)


#######################################################################
class mobilenet_v2_tv_gws(MobileNetV2TVGWS):
    # Lowercase class name: used as a model-factory entry point.
    # `pretrained` may be False or a truthy value (presumably a checkpoint
    # path/URL) that is passed straight to xnn.utils.load_weights -- TODO
    # confirm the accepted types against load_weights.
    def __init__(self, pretrained=False, **kwargs):
        super().__init__(**kwargs)
        if pretrained:
            # NOTE(review): rebinding `self` has no effect on the caller; this
            # relies on load_weights updating the model in place -- verify.
            self = xnn.utils.load_weights(self, pretrained)



